import argparse
import time

def parse_args():
    """Parse the command-line options of the block-diagonal SPCA experiment.

    Returns:
        argparse.Namespace with attributes ``k`` (int), ``time_limit`` (int,
        seconds), ``filepath`` (str), ``d_max`` (int) and ``tol`` (float).
    """
    parser = argparse.ArgumentParser(description="Run Julia branch and bound with customizable parameters.")
    # ``%(default)s`` is filled in by argparse, so the help text always shows
    # the value that is actually used.
    parser.add_argument('--k', type=int, default=3,
                        help='Specify the parameter k (default: %(default)s)')
    parser.add_argument('--time_limit', type=int, default=3600,
                        help='Specify the time limit (in seconds, default: %(default)s)')
    parser.add_argument('--filepath', type=str, default='Matrix_CovColon_txt',
                        help='File path of the dataset (default: %(default)s)')
    parser.add_argument('--d_max', type=int, default=30,
                        help='Maximum size of the block allowed (default: %(default)s)')
    parser.add_argument('--tol', type=float, default=0.01,
                        help='Relative tolerance; the actual tolerance is tol * max absolute entry of A '
                             '(default: %(default)s)')
    return parser.parse_args()



import numpy as np
from scipy.linalg import sqrtm


import random

def sum_of_squares(elements):
    """Return the sum of ``x ** 2`` over every item ``x`` of *elements* (0 if empty)."""
    squared = [value ** 2 for value in elements]
    return sum(squared)

def calculate_median_of_sums(matrix, alpha):
    """Return a median-of-means estimate of the mean squared entry of *matrix*.

    The entries of *matrix* are shuffled with :func:`random.shuffle` and dealt
    round-robin into ``int(d ** alpha)`` groups, where ``d = matrix.shape[0]``.
    The mean of the squared entries is computed per group, and the median of
    those group means is returned.

    The group count is clamped to ``[1, number of entries]`` so that no group
    is ever empty. An empty group would otherwise divide by zero, e.g. for
    ``alpha > 2`` on a square matrix.

    Raises:
        ValueError: if *matrix* has no entries.
    """
    d = matrix.shape[0]  # dimension of the matrix
    elements = list(matrix.flatten())
    if not elements:
        raise ValueError("matrix must contain at least one entry")

    # d^alpha groups, but at least one and never more groups than entries.
    num_lists = min(max(int(d ** alpha), 1), len(elements))

    random.shuffle(elements)

    # Round-robin split: group i gets elements i, i + num_lists, ...
    lists = [elements[i::num_lists] for i in range(num_lists)]

    # Mean of the squared entries in each (non-empty) group.
    mean_sum_squares = [sum_of_squares(lst) / len(lst) for lst in lists]

    return np.median(mean_sum_squares)

def threshold_matrix(matrix, threshold):
    """Return a new array equal to *matrix* with small entries zeroed.

    An entry is zeroed when ``|a_ij| < threshold``. Every other entry is kept
    as-is, including NaNs, because they fail the comparison.
    """
    keep = ~(np.abs(matrix) < threshold)
    return np.where(keep, matrix, 0)

def find_supp(x, ind):
    """Translate the non-zero positions of *x* into labels taken from *ind*.

    ``np.nonzero(x)[0]`` gives the (first-axis) positions where *x* is
    non-zero, in increasing order. Each position ``p`` is replaced by
    ``ind[p]``. In this file, *x* is a block solution and *ind* the indices of
    the original matrix that the block was cut from, so the result is the
    support expressed in original indices.
    """
    local_positions = np.nonzero(x)[0]
    support = []
    for pos in local_positions:
        support.append(ind[pos])
    return support

class UnionFind:
    """Disjoint-set forest over the integers ``0 .. size - 1``.

    ``union(i, j)`` hangs the root of ``j``'s set under the root of ``i``'s set
    (no balancing). ``find`` compresses paths so every node it visits ends up
    pointing directly at its root.
    """

    def __init__(self, size):
        # parent[i] == i marks i as the root of its own set.
        self.parent = list(range(size))

    def find(self, i):
        """Return the root of the set containing *i*, compressing the path.

        This is iterative rather than recursive. ``union`` never balances the
        trees, so a chain can be as long as ``size``, and a recursive walk would
        exceed Python's recursion limit on large inputs. The resulting
        ``parent`` state is the same as with the recursive version.
        """
        parent = self.parent
        root = i
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the walked path straight at the root.
        while i != root:
            next_node = parent[i]
            parent[i] = root
            i = next_node
        return root

    def union(self, i, j):
        """Merge the sets containing *i* and *j*; *i*'s root becomes the new root."""
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            self.parent[root_j] = root_i


def find_block_diagonals(A, matrix):
    """Split *A* into the diagonal blocks induced by the non-zero pattern of *matrix*.

    Indices ``v`` and ``w`` are adjacent when ``matrix[v, w] != 0``. In this
    file, *matrix* is a thresholded copy of *A*. Connected components are found
    with an iterative depth-first search. Each component becomes the block
    ``A[component][:, component]``, and blocks whose entries in *A* are all
    zero are dropped.

    Args:
        A: matrix the blocks are extracted from.
        matrix: matrix whose non-zero pattern defines the graph. Its
            ``shape[1]`` is the number of nodes, and row ``v`` lists the edges
            out of ``v``.

    Returns:
        ``(block_diagonals, indices, d_star)``:
        - the kept blocks;
        - for each block, its indices into *A* in DFS discovery order
          (matching the block's rows and columns);
        - the size of the largest kept block, or 0 if none is kept.
    """
    d = matrix.shape[1]
    visited = np.zeros(d, dtype=bool)

    block_diagonals = []
    indices = []
    d_star = 0

    for start in range(d):
        if visited[start]:
            continue

        component = []
        stack = [start]
        while stack:
            v = stack.pop()
            if visited[v]:
                continue
            visited[v] = True
            component.append(v)
            # Vectorised neighbour scan instead of a Python loop over all d
            # columns. flatnonzero yields ascending indices, so nodes are pushed
            # in the same order as a column-by-column scan would push them.
            neighbours = np.flatnonzero((matrix[v] != 0) & ~visited)
            stack.extend(neighbours.tolist())

        block = A[np.ix_(component, component)]
        if np.all(block == 0):
            continue
        block_diagonals.append(block)
        indices.append(component)
        d_star = max(d_star, len(component))

    return block_diagonals, indices, d_star



def solve_bd_spca(A, k, block_diagonals, indices, time_limit):
    """Solve k-sparse PCA on each diagonal block and keep the best block solution.

    Args:
        A: the original (un-thresholded) matrix. Assumed symmetric; in this
            script it is a covariance matrix.
        k: sparsity level.
        block_diagonals: blocks of the thresholded A, e.g. as returned by
            ``find_block_diagonals``.
        indices: for each block, the indices of A it was extracted from.
        time_limit: ``timeCap`` passed to the Julia ``branchAndBound`` for
            each block that is sent to it.

    Blocks with fewer than ``k`` rows cannot make the sparsity constraint bind.
    They are solved directly by their leading eigenpair. Larger blocks go to
    the Julia ``branchAndBound`` through the global ``Main``; an exception from
    that call is printed and the block is skipped.

    Returns:
        ``(x_best, ind_best, obj_best, original_obj, time_total)``:
        - ``x_best``: the best solution in block coordinates (a ``(1, 1)`` zero
          array if no block improved on 0);
        - ``ind_best``: the indices of A for that block;
        - ``obj_best``: its objective;
        - ``original_obj``: ``x_best' A[ind_best][:, ind_best] x_best``;
        - ``time_total``: the accumulated solve time in seconds.
    """
    time_total = 0

    obj_best = 0
    ind_best = [0]
    x_best = np.zeros((1, 1))

    for bd, block_ind in zip(block_diagonals, indices):
        if np.all(bd == 0):
            continue

        if bd.shape[0] < k:
            start_time = time.time()
            # eigh rather than eig: the block is symmetric, so eigh returns
            # real eigenvalues in ascending order. eig may return complex,
            # unordered eigenpairs.
            eigenvalues, eigenvectors = np.linalg.eigh(bd)
            max_eigenvalue = eigenvalues[-1]
            if max_eigenvalue > obj_best:
                obj_best = max_eigenvalue
                ind_best = block_ind
                x_best = eigenvectors[:, -1]
            time_total += time.time() - start_time
            continue

        # Convert numpy arrays to Julia arrays
        Main.Sigma = Main.eval("Array{Float64}")(bd)
        Main.data = Main.eval("Array{Float64}")(np.real(sqrtm(bd)))
        # Create an instance of the problem struct
        Main.eval('prob = problem(data, Sigma)')

        start_time = time.time()
        try:
            results = Main.eval(f"branchAndBound(prob, {k}, timeCap={time_limit})")
            obj, xVal, timetoBound, timetoConverge, timeOut, explored, toPrint, finalGap = results
            if np.real(obj) > obj_best:
                obj_best = np.real(obj)
                ind_best = block_ind
                x_best = xVal
        except Exception as e:
            # Best-effort: report and move on to the next block.
            print("An error occurred:", str(e))
        time_total += time.time() - start_time

    start_time = time.time()
    # Objective of the chosen block solution on the original A.
    original_obj = x_best.T @ A[np.ix_(ind_best, ind_best)] @ x_best
    time_total += time.time() - start_time

    return x_best, ind_best, obj_best, original_obj, time_total

def _refine_on_support(A, x, ind):
    """Re-score a block solution by the leading eigenpair of A on its support.

    The support is ``find_supp(x, ind)``, i.e. the indices of A where *x* is
    non-zero. The function prints the support and, when the support is
    non-empty, the re-scored objective.

    Returns:
        ``(value, y)``, where ``y`` is the leading eigenvector of
        ``A[supp][:, supp]`` and ``value = y' A[supp][:, supp] y``. Returns
        ``(None, None)`` when the support is empty.
    """
    supp = find_supp(x, ind)
    print(f'Current support is {supp}.')
    if not supp:
        return None, None
    sub = A[np.ix_(supp, supp)]
    _, eigvecs = np.linalg.eigh(sub)
    y1 = eigvecs[:, -1]
    better_PC_value = y1.T @ sub @ y1
    print(f"Better obj found is {better_PC_value}.")
    return better_PC_value, y1


def solve_bd_spca_bs(A, k, initial_threshold, a = 0.1, b = 10, max_d = 100, tol = 5e-2, time_limit = 600):
    """Bisect over the threshold eps and solve block-diagonal SPCA at each step.

    How it works:
    - A is thresholded at eps, split into blocks with ``find_block_diagonals``,
      and each instance is solved with ``solve_bd_spca``.
    - The search interval ``[L, U]`` starts at
      ``[a * initial_threshold, b * initial_threshold]``. The first instance is
      solved at ``U``.
    - If the midpoint gives no larger block than already solved, ``U``
      decreases. If the largest block exceeds ``max_d``, ``L`` increases.
    - Every solved instance is re-scored by the leading eigenvalue of A on its
      support, and the best re-scored instance is returned.
    - ``time_limit`` (seconds) is the overall wall-clock budget.

    Returns:
        ``(output_sol, output_ind, temp_obj, output_obj, total_time, output_eps)``:
        - ``output_sol``: the leading eigenvector on the best support;
        - ``output_ind``: the block index list it came from, or ``None`` if no
          instance produced a non-empty support with a positive value;
        - ``temp_obj``: the solver objective of the last instance;
        - ``output_obj``: the re-scored best objective;
        - ``total_time``: total wall time;
        - ``output_eps``: the threshold that produced the best solution.

        NOTE(review): ``output_sol`` is indexed by the support
        (``find_supp``), not by ``output_ind``.
    """
    start_time = time.time()

    def remaining_time():
        # Unused part of the overall budget, never negative.
        return max(time_limit - (time.time() - start_time), 0)

    U = b * initial_threshold
    L = a * initial_threshold

    S = threshold_matrix(A, U)
    block_diagonals, indices, d_star = find_block_diagonals(A, S)

    print(f"The sorting time is {time.time() - start_time}, initial maximum size is {d_star}")

    print("Now solving the first spca instance.")
    best_x, best_ind, temp_obj, best_obj, time_passed = solve_bd_spca(
        A, k, block_diagonals, indices, remaining_time())
    print(f"Solving time for the first SPCA is {time_passed}")

    output_sol = np.zeros(A.shape[0])
    output_obj = 0
    output_ind = None
    output_eps = 0

    # Score the first instance too. Without this its solution is dropped, and
    # output_ind stays None whenever the loop below never solves another one.
    value, vector = _refine_on_support(A, best_x, best_ind)
    if value is not None and value > output_obj:
        output_obj, output_sol, output_ind, output_eps = value, vector, best_ind, U

    i = 2
    # record the max d_star visited within computational constraints
    max_d_star = 0

    while U - L > tol:
        if max_d_star == max_d:
            # The computational limit was already reached in a previous
            # attempt, so stop immediately.
            break

        if d_star <= max_d:
            max_d_star = max(d_star, max_d_star)
        best_obj_old = best_obj

        M = (U + L) / 2
        start_sorting_time = time.time()
        S = threshold_matrix(A, M)
        block_diagonals, indices, d_star = find_block_diagonals(A, S)

        if d_star <= max_d_star:
            # No larger block than already solved: results are the same or
            # worse, so move the upper end down.
            U = M
            print(f"Current threshold is {M}, and gives the same d_star.\n")
            continue

        if d_star > max_d:
            # We cannot afford such computation: move the lower end up.
            L = M
            print(f"Current threshold is {M}, d_star is {d_star}, and exceeds computational resource.\n")
            continue

        budget = remaining_time()
        if budget <= 0:
            print("Time limit reached.\n")
            break

        print(f"Now running block diagonal spca for threshold {M}, with d_star being {d_star}.")
        print(f"This is the {i}-th spca instance.")
        i += 1

        sorting_time = time.time() - start_sorting_time

        # Pass only the remaining budget, as for the first instance, rather
        # than granting every later instance a fresh full time_limit.
        best_x, best_ind, temp_obj, best_obj, time_passed = solve_bd_spca(
            A, k, block_diagonals, indices, budget)
        print(f"Best obj found is {best_obj}. The runtime for this instance is {time_passed + sorting_time}.")
        print(f"Best index set found is {best_ind}.")

        value, vector = _refine_on_support(A, best_x, best_ind)
        if value is not None and value > output_obj:
            output_obj, output_sol, output_ind, output_eps = value, vector, best_ind, M

        total_time = time.time() - start_time
        print(f"The total runtime is {total_time}.")

        if total_time >= time_limit:
            print("Time limit reached.\n")
            break

        if abs(best_obj - best_obj_old) < 1e-2:
            print("Unchanged objective value detected.")
            U = M
        print("\n")

    total_time = time.time() - start_time
    print(f"Best opt found is {best_obj}.")
    print(f"Total runtime is {total_time}.")
    print("\n")

    return output_sol, output_ind, temp_obj, output_obj, total_time, output_eps
    


# Parse the command-line options (k, time_limit, filepath, d_max, tol).
args = parse_args()



import os
# Put the path to the Julia executable here, e.g. xxxx/yyyy/julia-1.6.7/bin/julia
julia_path = ''
exists = os.path.exists(julia_path)
executable = os.access(julia_path, os.X_OK)

print("Julia path exists:", exists)
print("Julia is executable:", executable)

# NOTE: a bad path only prints a warning. Execution is not stopped, and
# `from julia import Main` below is still attempted (presumably against
# whatever Julia pyjulia locates by itself; TODO confirm).
if not exists or not executable:
    print("Check the path to Julia or permissions.")
else:
    # Initialise pyjulia with the explicit runtime. This presumably has to run
    # before `from julia import Main` below for that runtime to be used.
    from julia.api import Julia
    jl = Julia(runtime=julia_path, compiled_modules=False)

# Load the Julia sources and time how long it takes. They are expected to
# define `problem` and `branchAndBound`, which later code calls via Main.eval.
time_start_loading = time.time()
print("Importing Main...")
from julia import Main
print("Including utilities.jl...")
Main.include("utilities.jl")
print("Including branchAndBound.jl...")
Main.include("branchAndBound.jl")
time_loading = time.time() - time_start_loading
print(f"Loading time for julia package(s) is {time_loading}s.")




file_path = args.filepath

# Load the dataset. A non-square array is treated as raw data (rows are
# samples) and replaced by its covariance matrix.
A = np.genfromtxt(file_path, delimiter=',')
n, d = A.shape

if n != d:
    # we are not getting a square matrix, thus transformation is needed
    # NOTE(review): a square raw-data file would be taken as a covariance matrix.
    A = np.cov(A.T)

sigma_np = A
data_np = np.real(sqrtm(sigma_np))

print("Data loaded from txt and processed.")

print("Now starting to solve the original spca.")
# Convert numpy arrays to Julia arrays
Main.data = Main.eval("Array{Float64}")(data_np)
Main.Sigma = Main.eval("Array{Float64}")(sigma_np)

# Create an instance of the problem struct
Main.eval('prob = problem(data, Sigma)')

print('Data loaded to juliapy.')

print('\n\n\n')

k = args.k
time_limit = args.time_limit

try:
    results = Main.eval(f"branchAndBound(prob, {k}, timeCap={time_limit})")
    spca_obj, spca_xVal, spca_timetoBound, spca_timetoConverge,\
        spca_timeOut, explored, toPrint, spca_finalGap = results
except Exception as e:
    # Report and continue with the block-diagonal method. The spca_* results
    # are undefined on this path, so they are only printed in `else`.
    print("An error occurred:", str(e))
else:
    print("Objective value:", spca_obj)
    print("Time to bound:", spca_timetoBound)
    print("Time to converge:", spca_timetoConverge)
    print("Timeout:", spca_timeOut)
    print("Nodes explored:", explored)
    print("Final gap:", spca_finalGap)

print('\n\n\n')

print("Now running block diagonal method:")

# Alternative, data-driven starting threshold (currently disabled):
#   var = calculate_median_of_sums(A, 1.5)
#   initial_threshold = np.sqrt(var) * np.log(d)
initial_threshold = 1

# Largest absolute entry of A. With a = 0 and b = infnorm, the bisection
# searches thresholds in [0, infnorm], and the tolerance is scaled by it.
infnorm = np.max(np.abs(A))

best_x, best_ind, temp_obj, best_obj, total_time, current_threshold \
    = solve_bd_spca_bs(A, k, initial_threshold, a = 0, b = infnorm, max_d = args.d_max, tol = args.tol * infnorm, time_limit=time_limit)

print(f"Best obj found is {best_obj}.")
# best_ind is None when no instance produced a usable support.
print(f"Instance size is {len(best_ind) if best_ind is not None else 0}.")
print(f"Total runtime is {total_time}.")
print(f"Best threshold found is {current_threshold}.")

    
    






